/**
 * Disjoint-set forest over the integers {@code 0..n-1}.
 *
 * <p>{@link #find(int)} compresses the path it walks, and {@link #union(int, int)}
 * attaches the smaller set beneath the larger one.
 */
class UnionFind {
    /** {@code p[i]} is the parent of {@code i}; a representative is its own parent. */
    private final int[] p;
    /** {@code size[r]} is the element count of the set represented by {@code r}; stale for non-representatives. */
    private final int[] size;

    /**
     * Creates {@code n} singleton sets, one per element {@code 0..n-1}.
     *
     * @param n number of elements
     */
    public UnionFind(int n) {
        p = new int[n];
        size = new int[n];
        for (int i = 0; i < n; ++i) {
            p[i] = i;
            size[i] = 1;
        }
    }

    /**
     * Returns the representative of the set containing {@code x} and re-parents every
     * element visited on the way directly to that representative.
     *
     * @param x element to look up
     * @return representative of {@code x}'s set
     */
    public int find(int x) {
        // First pass: locate the representative.
        int root = x;
        while (p[root] != root) {
            root = p[root];
        }
        // Second pass: point every element on the walked path straight at it.
        while (p[x] != root) {
            int next = p[x];
            p[x] = root;
            x = next;
        }
        return root;
    }

    /**
     * Merges the sets containing {@code a} and {@code b}.
     *
     * @param a an element of the first set
     * @param b an element of the second set
     * @return {@code false} if both were already in the same set, {@code true} if a merge happened
     */
    public boolean union(int a, int b) {
        int pa = find(a), pb = find(b);
        if (pa == pb) {
            return false;
        }
        // Attach the smaller set under the larger; on a tie, pa goes under pb.
        if (size[pa] > size[pb]) {
            p[pb] = pa;
            size[pa] += size[pb];
        } else {
            p[pa] = pb;
            size[pb] += size[pa];
        }
        return true;
    }

    /**
     * Returns the number of elements in the set containing {@code x}.
     *
     * <p>The argument is resolved through {@link #find(int)} first, so any member of the
     * set may be passed, not only its representative.
     *
     * @param x any element of the set
     * @return element count of that set
     */
    public int size(int x) {
        return size[find(x)];
    }
}

class Solution {
    /**
     * Chooses the node from {@code initial} that scores highest under the rule below.
     *
     * <p>Nodes not listed in {@code initial} are grouped into components using only edges
     * between two such nodes. Each node in {@code initial} is then scored by the total size
     * of the components it is adjacent to that no other entry of {@code initial} touches.
     * The highest score wins; ties go to the smaller node index.
     *
     * @param graph   adjacency matrix; {@code graph[i][j] == 1} is treated as an edge
     * @param initial node indices to choose from
     * @return the selected node, or {@code 0} when {@code initial} is empty
     */
    public int minMalwareSpread(int[][] graph, int[] initial) {
        final int total = graph.length;

        boolean[] listed = new boolean[total];
        for (int node : initial) {
            listed[node] = true;
        }

        UnionFind clean = joinUnlistedNodes(graph, listed);

        // touched[v] holds the component roots adjacent to listed node v.
        Set<Integer>[] touched = new Set[total];
        Arrays.setAll(touched, idx -> new HashSet<>());
        // reach[r] counts how many entries of `initial` are adjacent to component r.
        int[] reach = new int[total];
        for (int src : initial) {
            Set<Integer> roots = touched[src];
            for (int nb = 0; nb < total; ++nb) {
                if (listed[nb]) {
                    continue;
                }
                if (graph[src][nb] == 1) {
                    roots.add(clean.find(nb));
                }
            }
            for (int root : roots) {
                reach[root]++;
            }
        }

        int best = 0;
        int bestScore = -1;
        for (int src : initial) {
            // Only components reached by exactly one listed node count toward its score.
            int score = 0;
            for (int root : touched[src]) {
                if (reach[root] == 1) {
                    score += clean.size(root);
                }
            }
            boolean strictlyBetter = score > bestScore;
            boolean tieWithSmallerIndex = score == bestScore && src < best;
            if (strictlyBetter || tieWithSmallerIndex) {
                best = src;
                bestScore = score;
            }
        }
        return best;
    }

    /** Unions every pair of unlisted nodes joined by an edge, scanning the upper triangle of the matrix. */
    private static UnionFind joinUnlistedNodes(int[][] graph, boolean[] listed) {
        final int total = graph.length;
        UnionFind sets = new UnionFind(total);
        for (int u = 0; u < total; ++u) {
            if (listed[u]) {
                continue;
            }
            for (int v = u + 1; v < total; ++v) {
                if (graph[u][v] == 1 && !listed[v]) {
                    sets.union(u, v);
                }
            }
        }
        return sets;
    }
}